{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9d4a6c1c",
   "metadata": {},
   "source": [
    "# Assignment 3: Toxicity Classification in Online Comments "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e69fecf",
   "metadata": {},
   "source": [
    "### Process Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "eeab1cf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "7820ac56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dev, train and test splits from their TF-IDF CSV files (paths are relative to the working directory).\n",
    "dev_dataset = pd.read_csv(\"dataset/dev_tfidf.csv\")\n",
    "train_dataset = pd.read_csv(\"dataset/train_tfidf.csv\")\n",
    "test_dataset = pd.read_csv(\"dataset/test_tfidf.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "9e9b5d3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#split the feature and label set\n",
    "def split_f_l(df):\n",
    "    features = df.iloc[:, 26:].to_numpy()\n",
    "    label = df.iloc[:,1].to_numpy()\n",
    "    return features, label\n",
    "\n",
    "# Positional feature/label split of each dataset; the second return value is discarded for the test split.\n",
    "dev_features, dev_label = split_f_l(dev_dataset)\n",
    "train_features, train_label = split_f_l(train_dataset)\n",
    "test_features, _ = split_f_l(test_dataset)\n",
    "# IDs are kept so the final predictions can be written out next to them.\n",
    "test_ids = test_dataset[\"ID\"].tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "124b42b7",
   "metadata": {},
   "source": [
    "### Supervised ML algorithm and Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "3e94a1ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
    "\n",
    "def print_score(clf_name,classifier):\n",
    "    \"\"\"Print dev-set accuracy and weighted precision / recall / F1 for ``classifier``.\n",
    "\n",
    "    clf_name   -- text printed in front of every metric line.\n",
    "    classifier -- fitted estimator exposing ``predict``.\n",
    "\n",
    "    Reads the notebook-level ``dev_features`` and ``dev_label``.\n",
    "    \"\"\"\n",
    "    dev_preds = classifier.predict(dev_features)\n",
    "    # zero_division=0 yields the same 0.0 fallback for never-predicted classes as the\n",
    "    # default, but without emitting UndefinedMetricWarning (e.g. for a constant baseline).\n",
    "    precision, recall, f1, _ = precision_recall_fscore_support(\n",
    "        dev_label, dev_preds, average=\"weighted\", zero_division=0\n",
    "    )\n",
    "\n",
    "    # Accuracy is computed from the predictions already made rather than through\n",
    "    # classifier.score(), which would run predict() over the dev set a second time.\n",
    "    print(clf_name,\" Accuracy\", accuracy_score(dev_label, dev_preds))\n",
    "    print(clf_name,\" Precision score is \", precision)\n",
    "    print(clf_name,\" Recall score is \", recall)\n",
    "    print(clf_name,\" F1_score is \", f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "88bb84cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Naive Bayes  Accuracy 0.6368\n",
      "Naive Bayes  Precision score is  0.7917125322109563\n",
      "Naive Bayes  Recall score is  0.6368\n",
      "Naive Bayes  F1_score is  0.6761213728499024\n"
     ]
    }
   ],
   "source": [
    "#naive bayes\n",
    "from sklearn.naive_bayes import GaussianNB\n",
    "\n",
    "# Fit Gaussian Naive Bayes on the training split, then report its dev-set metrics.\n",
    "NBclassifier = GaussianNB()\n",
    "NBclassifier.fit(train_features, train_label)\n",
    "\n",
    "print_score(\"Naive Bayes\",NBclassifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "00813ae7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logistic Regressions  Accuracy 0.8276\n",
      "Logistic Regressions  Precision score is  0.8015794457107356\n",
      "Logistic Regressions  Recall score is  0.8276\n",
      "Logistic Regressions  F1_score is  0.7894524512179725\n"
     ]
    }
   ],
   "source": [
    "#logistic regression\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# Logistic regression with a fixed random_state and an iteration cap of 300, fitted on the training split.\n",
    "LRclassifier = LogisticRegression(random_state=66, max_iter=300)\n",
    "LRclassifier.fit(train_features, train_label)\n",
    "\n",
    "print_score(\"Logistic Regressions\",LRclassifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "375f0de5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decision Tree  Accuracy 0.7659333333333334\n",
      "Decision Tree  Precision score is  0.7494172521425234\n",
      "Decision Tree  Recall score is  0.7659333333333334\n",
      "Decision Tree  F1_score is  0.7567892463894531\n"
     ]
    }
   ],
   "source": [
    "#decision tree\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "# Decision tree with a fixed random_state, fitted on the training split.\n",
    "DTclassifier = DecisionTreeClassifier(random_state=0)\n",
    "DTclassifier.fit(train_features, train_label)\n",
    "\n",
    "print_score(\"Decision Tree\",DTclassifier)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2eb95e9f",
   "metadata": {},
   "source": [
    "### Baseline Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "d50f74a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Zero Rule baseline accuracy  Accuracy 0.811\n",
      "Zero Rule baseline accuracy  Precision score is  0.657721\n",
      "Zero Rule baseline accuracy  Recall score is  0.811\n",
      "Zero Rule baseline accuracy  F1_score is  0.7263622308117064\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/homebrew/Cellar/jupyterlab/3.3.2/libexec/lib/python3.9/site-packages/sklearn/metrics/_classification.py:1334: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "from sklearn.dummy import DummyClassifier\n",
    "# Baseline that always predicts the most frequent training label, for comparison with the models above.\n",
    "Frequentc = DummyClassifier(strategy=\"most_frequent\")\n",
    "Frequentc.fit(train_features, train_label)\n",
    "\n",
    "print_score(\"Zero Rule baseline accuracy\",Frequentc)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e0ad75f",
   "metadata": {},
   "source": [
    "### Research Question 2: Bias across different sub-groups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "fc375942",
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_sub_group(df):\n",
    "    christian_df = df[(df['Christian']==1)&(df['Muslim']==0)&(df['Female']==0)&(df['Homosexual gay or lesbian']==0)&(df['Male']==0)]\n",
    "    muslim_df = df[(df['Christian']==0)&(df['Muslim']==1)&(df['Female']==0)&(df['Homosexual gay or lesbian']==0)&(df['Male']==0)]\n",
    "    female_df = df[(df['Christian']==0)&(df['Muslim']==0)&(df['Female']==1)&(df['Homosexual gay or lesbian']==0)&(df['Male']==0)]\n",
    "    homosexual_df = df[(df['Christian']==0)&(df['Muslim']==0)&(df['Female']==0)&(df['Homosexual gay or lesbian']==1)&(df['Male']==0)]\n",
    "    male_df = df[(df['Christian']==0)&(df['Muslim']==0)&(df['Female']==0)&(df['Homosexual gay or lesbian']==0)&(df['Male']==1)]\n",
    "    return christian_df, muslim_df, female_df, homosexual_df, male_df\n",
    "# Dev set split by identity subgroup, then each subgroup split into features and labels\n",
    "christian_df, muslim_df, female_df, homosexual_df, male_df = select_sub_group(dev_dataset)\n",
    "christian_features, christian_labels = split_f_l(christian_df)\n",
    "muslim_features, muslim_labels = split_f_l(muslim_df)\n",
    "female_features, female_labels = split_f_l(female_df)\n",
    "homosexual_features, homosexual_labels = split_f_l(homosexual_df)\n",
    "male_features, male_labels = split_f_l(male_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "b3a8f2a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_bias_variance(name,classifier):\n",
    "    \"\"\"Print ``classifier``'s accuracy on each dev subgroup and the variance of those accuracies.\n",
    "\n",
    "    name       -- text printed in front of every line.\n",
    "    classifier -- fitted estimator exposing ``score``.\n",
    "\n",
    "    Reads the notebook-level ``<group>_features`` / ``<group>_labels`` arrays.\n",
    "    \"\"\"\n",
    "    subgroup_table = [\n",
    "        (\"accuracy in christian group \", christian_features, christian_labels),\n",
    "        (\"accuracy in muslim group \", muslim_features, muslim_labels),\n",
    "        (\"accuracy in female group \", female_features, female_labels),\n",
    "        (\"accuracy in homosexual group \", homosexual_features, homosexual_labels),\n",
    "        (\"accuracy in male group \", male_features, male_labels),\n",
    "    ]\n",
    "\n",
    "    dev_acc_scores = []\n",
    "    for message, group_features, group_labels in subgroup_table:\n",
    "        group_accuracy = classifier.score(group_features, group_labels)\n",
    "        dev_acc_scores.append(group_accuracy)\n",
    "        print(name, message, group_accuracy)\n",
    "\n",
    "    print(name,\"variance is \", np.var(dev_acc_scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "4565ffb0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR accuracy in christian group  0.9309473684210526\n",
      "LR accuracy in muslim group  0.7858974358974359\n",
      "LR accuracy in female group  0.84930966469428\n",
      "LR accuracy in homosexual group  0.7716535433070866\n",
      "LR accuracy in male group  0.8247903075489282\n",
      "LR variance is  0.0031815953747023154\n"
     ]
    }
   ],
   "source": [
    "# Per-subgroup dev accuracy, and its variance, for the logistic regression model.\n",
    "print_bias_variance(\"LR\",LRclassifier)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24abec85",
   "metadata": {},
   "source": [
    "### Solution 1: Class Balance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "de8aaff8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training set split by identity subgroup; the \"2\" suffix marks the train-side feature/label arrays\n",
    "christian_tdf, muslim_tdf, female_tdf, homosexual_tdf, male_tdf = select_sub_group(train_dataset)\n",
    "christian_features2, christian_labels2 = split_f_l(christian_tdf)\n",
    "muslim_features2, muslim_labels2 = split_f_l(muslim_tdf)\n",
    "female_features2, female_labels2 = split_f_l(female_tdf)\n",
    "homosexual_features2, homosexual_labels2 = split_f_l(homosexual_tdf)\n",
    "male_features2, male_labels2= split_f_l(male_tdf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "7d715275",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All ——Trainset Class distribution 0:1 = 5.226096237658988  Devset Class distribution 0:1 = 4.291005291005291\n",
      "----------------\n",
      "christian ——Trainset Class distribution 0:1 = 13.834617664493184  Devset Class distribution 0:1 = 12.268156424581006\n",
      "muslim ——Trainset Class distribution 0:1 = 3.5236749116607773  Devset Class distribution 0:1 = 3.262295081967213\n",
      "female ——Trainset Class distribution 0:1 = 6.406841783750764  Devset Class distribution 0:1 = 5.401515151515151\n",
      "homosexual ——Trainset Class distribution 0:1 = 2.645210727969349  Devset Class distribution 0:1 = 2.8484848484848486\n",
      "male ——Trainset Class distribution 0:1 = 5.427030913012222  Devset Class distribution 0:1 = 4.2727272727272725\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter\n",
    "def show_label_distribution(name,labels,labels2):\n",
    "    labels_counter = Counter(labels)\n",
    "    labels_counter2 = Counter(labels2)\n",
    "    print(name,\"——Trainset Class distribution 0:1 =\", float(labels_counter[0]) / float(labels_counter[1]),\n",
    "         \" Devset Class distribution 0:1 =\", float(labels_counter2[0]) / float(labels_counter2[1]))\n",
    "\n",
    "# Overall ratio first, then one line per identity subgroup (train labels first, dev labels second).\n",
    "show_label_distribution(\"All\",train_label,dev_label)\n",
    "\n",
    "print(\"----------------\")\n",
    "\n",
    "show_label_distribution(\"christian\",christian_labels2,christian_labels)\n",
    "show_label_distribution(\"muslim\",muslim_labels2,muslim_labels)\n",
    "show_label_distribution(\"female\",female_labels2,female_labels)\n",
    "show_label_distribution(\"homosexual\",homosexual_labels2,homosexual_labels)\n",
    "show_label_distribution(\"male\",male_labels2,male_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "3eb47593",
   "metadata": {},
   "outputs": [],
   "source": [
    "def downsample_class(group_df, frac, shrink_label=0, seed=0):\n",
    "    \"\"\"Return ``group_df`` with one Toxicity class randomly sampled down.\n",
    "\n",
    "    Rows whose ``Toxicity`` equals ``shrink_label`` are reduced to ``frac`` of\n",
    "    their number; rows of the other class (``1 - shrink_label``) are all kept.\n",
    "    ``seed`` is passed to ``DataFrame.sample`` so repeated runs pick the same rows.\n",
    "    \"\"\"\n",
    "    kept_rows = group_df[group_df[\"Toxicity\"] == 1 - shrink_label]\n",
    "    sampled_rows = group_df[group_df[\"Toxicity\"] == shrink_label].sample(frac=frac, random_state=seed)\n",
    "    return pd.concat([kept_rows, sampled_rows])\n",
    "\n",
    "# Every group is drawn from the *training* subgroup frames (*_tdf). Building them from the\n",
    "# dev frames (*_df) would train the model on the very rows it is later evaluated on.\n",
    "# NOTE(review): the fractions look chosen to move each group's train 0:1 ratio toward its\n",
    "# dev ratio printed by the distribution cell (e.g. 13.83 * 0.89 is about 12.3) - confirm.\n",
    "cdf_new = downsample_class(christian_tdf, frac=0.89)\n",
    "mdf_new = downsample_class(muslim_tdf, frac=0.92)\n",
    "fdf_new = downsample_class(female_tdf, frac=0.84)\n",
    "# For this group it is the positive class that is sampled down; negatives are kept whole.\n",
    "hdf_new = downsample_class(homosexual_tdf, frac=0.93, shrink_label=1)\n",
    "adf_new = downsample_class(male_tdf, frac=0.78)\n",
    "\n",
    "# Only rows belonging to exactly one of the five subgroups end up in the new training set.\n",
    "train_newdf = pd.concat([cdf_new,mdf_new,fdf_new,hdf_new,adf_new])\n",
    "train_newfeatures, train_newlabel = split_f_l(train_newdf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "d95c5110",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logistic Regressions with balanced label accuracy 0.8266666666666667\n",
      "Logistic Regressions with balanced label accuracy in christian group  0.9305263157894736\n",
      "Logistic Regressions with balanced label accuracy in muslim group  0.7884615384615384\n",
      "Logistic Regressions with balanced label accuracy in female group  0.855621301775148\n",
      "Logistic Regressions with balanced label accuracy in homosexual group  0.7779527559055118\n",
      "Logistic Regressions with balanced label accuracy in male group  0.8252562907735321\n",
      "Logistic Regressions with balanced label variance is  0.0030128165945094753\n"
     ]
    }
   ],
   "source": [
    "# Logistic regression fitted on train_newfeatures / train_newlabel, scored on the full dev set and per subgroup.\n",
    "LRclassifier_balanced = LogisticRegression(random_state=66, max_iter=300)\n",
    "LRclassifier_balanced.fit(train_newfeatures, train_newlabel)\n",
    "\n",
    "print(\"Logistic Regressions with balanced label accuracy\", LRclassifier_balanced.score(dev_features, dev_label))\n",
    "print_bias_variance(\"Logistic Regressions with balanced label\",LRclassifier_balanced)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fff30804",
   "metadata": {},
   "source": [
    "### Solution 2 (Overfitting): Bagging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "cd6f39bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logistic Regression bagging accuracy 0.8274666666666667\n",
      "Logistic Regression bagging accuracy accuracy in christian group  0.9305263157894736\n",
      "Logistic Regression bagging accuracy accuracy in muslim group  0.7884615384615384\n",
      "Logistic Regression bagging accuracy accuracy in female group  0.84930966469428\n",
      "Logistic Regression bagging accuracy accuracy in homosexual group  0.7716535433070866\n",
      "Logistic Regression bagging accuracy accuracy in male group  0.8233923578751164\n",
      "Logistic Regression bagging accuracy variance is  0.003123241534944309\n"
     ]
    }
   ],
   "source": [
    "from sklearn.ensemble import BaggingClassifier\n",
    "# The unfitted base estimator gets its own name: rebinding LRclassifier here would replace the\n",
    "# fitted model that earlier cells rely on when they are re-run.\n",
    "LR_bagging_base = LogisticRegression(random_state=66, max_iter=300)\n",
    "LRclassifier_bagging = BaggingClassifier(LR_bagging_base,n_estimators=6, random_state=6).fit(train_features, train_label)\n",
    "\n",
    "print(\"Logistic Regression bagging accuracy\", LRclassifier_bagging.score(dev_features, dev_label))\n",
    "# print_bias_variance adds \"accuracy in ... group\" itself, so the label must not end in \"accuracy\".\n",
    "print_bias_variance(\"Logistic Regression bagging\",LRclassifier_bagging)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0df0392d",
   "metadata": {},
   "source": [
    "### Feature engineering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "ced3e31d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the raw-text versions of the train / dev / test splits\n",
    "train_df = pd.read_csv(\"dataset/train_raw.csv\")\n",
    "dev_df = pd.read_csv(\"dataset/dev_raw.csv\")\n",
    "test_df = pd.read_csv(\"dataset/test_raw.csv\")\n",
    "# Split each frame into its comment text and its label\n",
    "# NOTE(review): train_label and dev_label are rebound here to the labels of the raw files; earlier\n",
    "# cells that read these names will see these arrays if re-run - confirm row order matches the TF-IDF files.\n",
    "train_raw_comment = train_df[\"Comment\"].tolist()\n",
    "train_label = train_df[\"Toxicity\"].to_numpy()\n",
    "dev_raw_comment = dev_df[\"Comment\"].tolist()\n",
    "dev_label = dev_df[\"Toxicity\"].to_numpy()\n",
    "test_raw_comment = test_df[\"Comment\"].tolist()\n",
    "\n",
    "# Raw dev frame split by identity subgroup, used for the per-group bag-of-words evaluation\n",
    "christian_raw_df, muslim_raw_df, female_raw_df, homosexual_raw_df, male_raw_df = select_sub_group(dev_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "09aa3cde",
   "metadata": {},
   "outputs": [],
   "source": [
    "#preprocessing\n",
    "#remove punctuation\n",
    "import string\n",
    "#defining the function to remove punctuation\n",
    "def remove_punctuation(text):\n",
    "    punctuationfree=\"\".join([i for i in text if i not in string.punctuation])\n",
    "    return punctuationfree\n",
    "# Lower-cased, punctuation-free copy of every training comment\n",
    "# NOTE(review): later cells vectorize train_raw_comment directly; this list is not read again in this notebook.\n",
    "train_removed_pun = [remove_punctuation(sen).lower() for sen in train_raw_comment]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "234db43c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#data preprocessing : bow\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "\n",
    "# The vocabulary is learned from the training comments only; dev comments are mapped onto that same vocabulary.\n",
    "bow_vectorizer = CountVectorizer(stop_words='english')\n",
    "train_bow = bow_vectorizer.fit_transform(train_raw_comment)\n",
    "dev_bow = bow_vectorizer.transform(dev_raw_comment)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "80565dec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_comment_label(df):\n",
    "    comment = df[\"Comment\"].tolist()\n",
    "    label =  df[\"Toxicity\"].to_numpy()\n",
    "    return comment, label\n",
    "def print_bias_variance_bow(name,classifier):\n",
    "    \"\"\"Print ``classifier``'s accuracy on each raw dev subgroup and the variance of those accuracies.\n",
    "\n",
    "    name       -- text printed in front of every line.\n",
    "    classifier -- fitted estimator exposing ``score`` on bag-of-words input.\n",
    "\n",
    "    Reads the notebook-level ``<group>_raw_df`` frames and ``bow_vectorizer``.\n",
    "    \"\"\"\n",
    "    raw_subgroups = [\n",
    "        (\"accuracy in christian group \", christian_raw_df),\n",
    "        (\"accuracy in muslim group \", muslim_raw_df),\n",
    "        (\"accuracy in female group \", female_raw_df),\n",
    "        (\"accuracy in homosexual group \", homosexual_raw_df),\n",
    "        (\"accuracy in male group \", male_raw_df),\n",
    "    ]\n",
    "\n",
    "    dev_acc_scores = []\n",
    "    for message, group_df in raw_subgroups:\n",
    "        group_comments, group_labels = split_comment_label(group_df)\n",
    "        group_bow = bow_vectorizer.transform(group_comments)\n",
    "        # Score once and reuse the value for both the variance list and the printout.\n",
    "        group_accuracy = classifier.score(group_bow, group_labels)\n",
    "        dev_acc_scores.append(group_accuracy)\n",
    "        print(name, message, group_accuracy)\n",
    "\n",
    "    print(name,\"variance is \", np.var(dev_acc_scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "ca184b95",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logistic Regressions bow with balanced label accuracy 0.8424666666666667\n",
      "Logistic Regressions bow with balanced label accuracy in christian group  0.9334736842105263\n",
      "Logistic Regressions bow with balanced label accuracy in muslim group  0.7846153846153846\n",
      "Logistic Regressions bow with balanced label accuracy in female group  0.8733727810650888\n",
      "Logistic Regressions bow with balanced label accuracy in homosexual group  0.7763779527559055\n",
      "Logistic Regressions bow with balanced label accuracy in male group  0.8420316868592731\n",
      "Logistic Regressions bow with balanced label variance is  0.0033901862326438286\n"
     ]
    }
   ],
   "source": [
    "# Plain (unweighted) logistic regression on the bag-of-words features, fitted on the full training set.\n",
    "LRclassifier_bow = LogisticRegression(random_state=666, max_iter=800)\n",
    "LRclassifier_bow.fit(train_bow, train_label)\n",
    "# Labelled without \"balanced\": this model uses neither class weights nor re-sampling.\n",
    "print(\"Logistic Regression bow accuracy\", LRclassifier_bow.score(dev_bow, dev_label))\n",
    "print_bias_variance_bow(\"Logistic Regression bow\",LRclassifier_bow)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1c81c96",
   "metadata": {},
   "source": [
    "### Final model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "aa892178",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LRbagging, bow with balanced label accuracy accuracy 0.8196666666666667\n"
     ]
    }
   ],
   "source": [
    "# The class-weighted base estimator gets its own name so the fitted LRclassifier_bow from the\n",
    "# feature-engineering section is not replaced by an unfitted one.\n",
    "LRclassifier_bow_balanced = LogisticRegression(random_state=666, max_iter=800, class_weight=\"balanced\")\n",
    "LRclassifier_final = BaggingClassifier(LRclassifier_bow_balanced,n_estimators=10, random_state=0).fit(train_bow, train_label)\n",
    "\n",
    "print(\"LRbagging, bow with balanced label accuracy\", LRclassifier_final.score(dev_bow, dev_label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "d955eea0",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_bow = bow_vectorizer.transform(test_raw_comment)\n",
    "\n",
    "test_preds = LRclassifier_final.predict(test_bow).tolist()\n",
    "# NOTE(review): test_ids come from the TF-IDF test file while the predictions are made on the raw\n",
    "# test file; pairing them assumes both files list rows in the same order - confirm.\n",
    "assert len(test_ids) == len(test_preds)\n",
    "\n",
    "# The context manager closes the file even if a write fails part-way through.\n",
    "with open(\"test_predictions.csv\", \"w\") as f:\n",
    "    f.write(\"ID,Toxicity\\n\")\n",
    "    for test_id, test_pred in zip(test_ids, test_preds):\n",
    "        f.write(str(test_id) + \",\" +str(test_pred) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7e53d30",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
